home *** CD-ROM | disk | FTP | other *** search
- /* =======================================================
- Neural Network Classes for the NeXT Computer
- Written by: Ralph Zazula
- University of Arizona - Fall 1991
- zazula@pri.com (NeXT Mail)
- ==========================================================*/
- /*$Log: Neuron.m,v $
- Revision 1.4 92/01/14 21:19:46 zazula
- Check in before starting HashTable mod
-
- Revision 1.3 92/01/02 14:04:31 zazula
- Faster linked-list for connections
- No more Storage object
-
- Revision 1.2 92/01/02 12:41:34 zazula
- Initial version - support for stochastic networks via temperature T
- */
- #import "Neuron.h"
- #import <appkit/nextstd.h>
- #import "math.h"
-
-
- //----------------------------------------------------------
-
- /*
-  * findConnection -- walk the connection list beginning at `list` and
-  * return the first node whose source is `source`, or NULL when `source`
-  * is not one of this neuron's inputs.  Shared by the weight lookup and
-  * update methods so the search logic lives in one place.
-  */
- static connection *findConnection(connection *list, id source)
- {
-     connection *C = list;
-
-     while (C != NULL && C->source != source)
-         C = (connection *)C->next;
-     return C;
- }
-
- @implementation Neuron
-
- // --- simple accessors ---------------------------------------
-
- - inputs { return inputs; }
- - setType:(int)type { nodeType = type; return self; }
- - (int)getType { return nodeType; }
- - setTemp:(double)newT { T = newT; return self; }
- - (double)getTemp { return T; }
- - setRandom:theRandom { random = theRandom; return self; }
- - setSymmetric:(BOOL)sym { Symmetric = sym; return self; }
- - (BOOL)getSymmetric { return Symmetric; }
-
- //----------------------------------------------------------
-
- /*
-  * activation: -- map a net input to this node's output according to
-  * nodeType.
-  *
-  *   Binary  : output 1.0 or 0.0
-  *   Sigmoid : 1/(1+exp(-net)); the temperature T is not used
-  *   Sign    : output 1.0 or -1.0
-  *   Tanh    : tanh(net/T) when T > 0, otherwise tanh(net)
-  *
-  * With T > 0, Binary and Sign nodes are stochastic: the "on" value is
-  * chosen when [random percent] <= 1/(1+exp(-2*net/T)).  This presumably
-  * treats percent as a uniform draw in [0,1]; confirm against Random.
-  * An unrecognized nodeType yields 0.0.  Previously that case returned an
-  * uninitialized value.
-  */
- - (double)activation:(double)net
- {
-     double temp = 0.0;   // result for an unrecognized nodeType
-
-     if (random == nil) random = [[Random alloc] init];   // lazily create RNG
-
-     switch (nodeType) {
-     case Binary :
-         // NOTE(review): the deterministic threshold below is 0.5, but the
-         // stochastic rule is centred on net == 0.  Confirm this is intended.
-         if (T > 0.0)
-             temp = ([random percent] <= 1.0/(1.0+exp(-2*net/T))) ? 1.0 : 0.0;
-         else
-             temp = (net > 0.5) ? 1.0 : 0.0;
-         break;
-     case Sigmoid :
-         temp = 1.0/(1.0+exp(-net));
-         break;
-     case Sign :
-         if (T > 0.0)
-             temp = ([random percent] <= 1.0/(1.0+exp(-2*net/T))) ? 1.0 : -1.0;
-         else
-             temp = (net > 0.0) ? 1.0 : -1.0;
-         break;
-     case Tanh :
-         temp = (T > 0.0) ? tanh(net/T) : tanh(net);
-         break;
-     default :
-         break;   // leave temp at 0.0
-     }
-
-     return temp;
- }
-
- //----------------------------------------------------------
-
- /*
-  * init -- set a fresh neuron's state: zero output, Sigmoid node type,
-  * zero temperature, an empty connection list, and non-symmetric weights.
-  */
- - init
- {
-     self = [super init];
-     if (self == nil) return nil;
-
-     lastOutput = 0.0;
-     nodeType = Sigmoid;     // default node type
-     T = 0.0;                // default temperature
-     head = tail = NULL;     // empty linked-list of connections
-     Symmetric = NO;         // default Symmetric connection status
-
-     return self;
- }
-
- //----------------------------------------------------------
-
- /*
-  * free -- release the malloc'd connection list, then the object itself.
-  * The source neurons are not freed; this object only references them.
-  * `random` is also left alone, because setRandom: may have installed an
-  * instance shared with other neurons.  NOTE(review): that means a Random
-  * created lazily by this neuron is not reclaimed here.
-  * NOTE(review): assumes the root class is NeXTSTEP's Object, whose
-  * deallocation method is -free.  Confirm against Neuron.h.
-  */
- - free
- {
-     connection *C = head;
-
-     while (C != NULL) {
-         connection *next = (connection *)C->next;
-         free(C);            // C library free(), not this method
-         C = next;
-     }
-     head = tail = NULL;
-
-     return [super free];
- }
-
- //-----------------------------------------------------------
-
- - step
- // update the output value based on our inputs
- {
-     double net = 0.0;
-     connection *C;
-
-     // Finish the whole weighted sum before lastOutput is overwritten.  A
-     // neuron wired to itself (feedback) then reads its previous output.
-     for (C = head; C != NULL; C = (connection *)C->next)
-         net += C->weight * [C->source lastOutput];
-
-     lastOutput = [self activation:net];
-
-     return self;
- }
-
- //-----------------------------------------------------------
-
- - (double)lastOutput
- {
-     return lastOutput;
- }
-
- //-----------------------------------------------------------
-
- /*
-  * connect: -- add sender as an input with a randomly drawn weight of
-  * [random percent]/10.0.  Presumably this falls in [0, 0.1] if percent
-  * is in [0,1]; confirm.
-  */
- - connect:sender
- {
-     if (random == nil) random = [[Random alloc] init];
-     return [self connect:sender withWeight:[random percent]/10.0];
- }
-
- //-----------------------------------------------------------
-
- - connect:sender withWeight:(double)weight
- //
- // Appends sender to the end of the list of inputs with the given weight.
- // Returns self, or nil (after a message on stderr) if allocation fails.
- // TODO: check that sender is a Neuron.
- // TODO: check whether sender is already in the list.  Duplicates are
- //       currently appended, and the weight lookups only see the first one.
- //
- {
-     connection *C = (connection *)malloc(sizeof(connection));
-
-     if (C == NULL) {
-         fprintf(stderr,"out of memory in connect:withWeight:\n");
-         return nil;
-     }
-     C->source = sender;
-     C->weight = weight;
-     C->next = NULL;
-
-     if (head == NULL)
-         head = C;
-     else
-         tail->next = C;
-     tail = C;
-
-     return self;
- }
-
- //-----------------------------------------------------------
-
- /*
-  * getWeightFor: -- weight of the connection from source.  If source is
-  * not an input, it prints a message and returns NAN.
-  */
- - (double)getWeightFor:source
- {
-     connection *C = findConnection(head, source);
-
-     if (C == NULL) {   // source isn't an input
-         fprintf(stderr,"connection not found in getWeightFor:\n");
-         return NAN;
-     }
-     return C->weight;
- }
-
- //-----------------------------------------------------------
-
- /*
-  * setWeightFor:to: -- overwrite the weight of the connection from
-  * source.  Returns self, or nil (after a message) if source is not an input.
-  */
- - setWeightFor:source to:(double)weight
- {
-     connection *C = findConnection(head, source);
-
-     if (C == NULL) {   // source isn't an input
-         fprintf(stderr,"connection not found in setWeightFor:to:\n");
-         return nil;
-     }
-     C->weight = weight;
-     return self;
- }
-
- //-----------------------------------------------------------
-
- - setOutput:(double)output
- {
-     lastOutput = output;
-
-     return self;
- }
-
- //-----------------------------------------------------------
-
- /*
-  * changeWeightFor:by: -- add delta to the weight of the connection from
-  * source.  When Symmetric is set, the new weight is also copied to
-  * source's connection back to this neuron.  Returns self, or nil (after a
-  * message) if source is not an input.
-  */
- - changeWeightFor:source by:(double)delta
- {
-     connection *C = findConnection(head, source);
-
-     if (C == NULL) {   // source isn't an input
-         fprintf(stderr,"connection not found in changeWeightFor:by:\n");
-         return nil;
-     }
-     C->weight += delta;
-     if (Symmetric)     // keep the reverse connection in step
-         [source setWeightFor:self to:C->weight];
-     return self;
- }
-
-
- @end
-